0d0169
@@ -16,118 +16,156 @@
 
 package org.springframework.security.oauth.provider.nonce;
 
-import org.springframework.security.core.AuthenticationException;
-import org.springframework.security.oauth.provider.ConsumerDetails;
+import java.util.Iterator;
+import java.util.TreeSet;
 
-import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.*;
+import org.springframework.security.authentication.CredentialsExpiredException;
+import org.springframework.security.oauth.provider.ConsumerDetails;
 
 /**
- * Expands on the {@link org.springframework.security.oauth.provider.nonce.ExpiringTimestampNonceServices} to
- * include validation of the nonce for replay protection.<br/><br/>
+ * Expands on the {@link org.springframework.security.oauth.provider.nonce.ExpiringTimestampNonceServices} to include
+ * validation of the nonce for replay protection.<br/>
+ * <br/>
  *
  * To validate the nonce, the InMemoryNonceService first validates the consumer key and timestamp as does the
- * {@link org.springframework.security.oauth.provider.nonce.ExpiringTimestampNonceServices}.  Assuming the consumer
- * and timestamp are valid, the InMemoryNonceServices further ensures that the specified nonce was not used with the
- * specified timestamp within the specified validity window.  The list of nonces used within the validity window
- * is kept in memory.<br/><br/>
+ * {@link org.springframework.security.oauth.provider.nonce.ExpiringTimestampNonceServices}. Assuming the consumer and
+ * timestamp are valid, the InMemoryNonceServices further ensures that the specified nonce was not used with the
+ * specified timestamp within the specified validity window. The list of nonces used within the validity window is kept
+ * in memory.
+ *
+ * Note: the default validity window in this class is different from the one used in
+ * {@link org.springframework.security.oauth.provider.nonce.ExpiringTimestampNonceServices}. The reason for this is that
+ * this class has a per request memory overhead. Keeping the validity window short helps prevent wasting a lot of
+ * memory. 10 minutes that allows for minor variations in time between servers.
  *
  * @author Ryan Heaton
+ * @author Jilles van Gurp
  */
-public class InMemoryNonceServices extends ExpiringTimestampNonceServices {
-
-  protected static final ConcurrentMap<String, LinkedList<TimestampEntry>> TIMESTAMP_ENTRIES = new ConcurrentHashMap<String, LinkedList<TimestampEntry>>();
-
-  @Override
-  public void validateNonce(ConsumerDetails consumerDetails, long timestamp, String nonce) throws AuthenticationException {
-    final long cutoff = (System.currentTimeMillis() / 1000) - getValidityWindowSeconds();
-    super.validateNonce(consumerDetails, timestamp, nonce);
-
-    String consumerKey = consumerDetails.getConsumerKey();
-    LinkedList<TimestampEntry> entries = TIMESTAMP_ENTRIES.get(consumerKey);
-    if (entries == null) {
-      entries = new LinkedList<TimestampEntry>();
-      TIMESTAMP_ENTRIES.put(consumerKey, entries);
-    }
-
-    synchronized (entries) {
-      if (entries.isEmpty()) {
-        entries.add(new TimestampEntry(timestamp, nonce));
-      }
-      else {
-        boolean isNew = entries.getLast().getTimestamp() < timestamp;
-        ListIterator<TimestampEntry> listIterator = entries.listIterator();
-        while (listIterator.hasNext()) {
-          TimestampEntry entry = listIterator.next();
-          if (entry.getTimestamp() < cutoff) {
-            listIterator.remove();
-            isNew = !listIterator.hasNext();
-          }
-          else if (isNew) {
-            //optimize for a new, latest timestamp
-            entries.addLast(new TimestampEntry(timestamp, nonce));
-            return;
-          }
-          else if (entry.getTimestamp() == timestamp) {
-            if (!entry.addNonce(nonce)) {
-              throw new NonceAlreadyUsedException("Nonce already used: " + nonce);
-            }
-            return;
-          }
-          else if (entry.getTimestamp() > timestamp) {
-            //insert a new entry just before this one.
-            entries.add(listIterator.previousIndex(), new TimestampEntry(timestamp, nonce));
-            return;
-          }
-        }
-
-        //got through the whole list; assume it's just a new one.
-        //this shouldn't happen because of the optimization above.
-        entries.addLast(new TimestampEntry(timestamp, nonce));
-      }
-    }
-  }
-
-  protected static class TimestampEntry {
-
-    private final Long timestamp;
-    private final Set<String> nonces = new HashSet<String>();
-
-    public TimestampEntry(long timestamp, String firstNonce) {
-      this.timestamp = timestamp;
-      this.nonces.add(firstNonce);
-    }
-
-    /**
-     * Adds a nonce to this timestamp entry.
-     *
-     * @param nonce The nonce to add.
-     * @return The nonce.
-     */
-    public boolean addNonce(String nonce) {
-      synchronized (nonces) {
-        return nonces.add(nonce);
-      }
-    }
-
-    /**
-     * Get the timestamp for this entry.
-     *
-     * @return The timestamp for this entry.
-     */
-    public Long getTimestamp() {
-      return timestamp;
-    }
-
-    @Override
-    public int hashCode() {
-      return timestamp.hashCode();
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-      return obj instanceof TimestampEntry && this.timestamp.equals(((TimestampEntry) obj).timestamp);
-    }
-  }
+public class InMemoryNonceServices implements OAuthNonceServices {
+
+	/**
+	 * Contains all the nonces that were used inside the validity window.
+	 */
+	static final TreeSet<NonceEntry> NONCES = new TreeSet<NonceEntry>();
+
+	private volatile long lastCleaned = 0;
+
+	// we'll default to a 10 minute validity window, otherwise the amount of memory used on NONCES can get quite large.
+	private long validityWindowSeconds = 60 * 10;
+
+	public void validateNonce(ConsumerDetails consumerDetails, long timestamp, String nonce) {
+		if (System.currentTimeMillis() / 1000 - timestamp > getValidityWindowSeconds()) {
+			throw new CredentialsExpiredException("Expired timestamp.");
+		}
+
+		NonceEntry entry = new NonceEntry(consumerDetails.getConsumerKey(), timestamp, nonce);
+
+		synchronized (NONCES) {
+			if (NONCES.contains(entry)) {
+				throw new NonceAlreadyUsedException("Nonce already used: " + nonce);
+			}
+			else {
+				NONCES.add(entry);
+			}
+			cleanupNonces();
+		}
+	}
+
+	private void cleanupNonces() {
+		long now = System.currentTimeMillis() / 1000;
+		// don't clean out the NONCES for each request, this would cause the service to be constantly locked on this
+		// loop under load. One second is small enough that cleaning up does not become too expensive.
+		// Also see SECOAUTH-180 for reasons this class was refactored.
+		if (now - lastCleaned > 1) {
+			Iterator<NonceEntry> iterator = NONCES.iterator();
+			while (iterator.hasNext()) {
+				// the nonces are already sorted, so simply iterate and remove until the first nonce within the validity
+				// window.
+				NonceEntry nextNonce = iterator.next();
+				long difference = now - nextNonce.timestamp;
+				if (difference > getValidityWindowSeconds()) {
+					iterator.remove();
+				}
+				else {
+					break;
+				}
+			}
+			// keep track of when cleanupNonces last ran
+			lastCleaned = now;
+		}
+	}
+
+	/**
+	 * Set the timestamp validity window (in seconds).
+	 *
+	 * @return the timestamp validity window (in seconds).
+	 */
+	public long getValidityWindowSeconds() {
+		return validityWindowSeconds;
+	}
+
+	/**
+	 * The timestamp validity window (in seconds).
+	 *
+	 * @param validityWindowSeconds the timestamp validity window (in seconds).
+	 */
+	public void setValidityWindowSeconds(long validityWindowSeconds) {
+		this.validityWindowSeconds = validityWindowSeconds;
+	}
+
+	/**
+	 * Representation of a nonce with the right hashCode, equals, and compareTo methods for the TreeSet approach above
+	 * to work.
+	 */
+	static class NonceEntry implements Comparable<NonceEntry> {
+		private final String consumerKey;
+
+		private final long timestamp;
+
+		private final String nonce;
+
+		public NonceEntry(String consumerKey, long timestamp, String nonce) {
+			this.consumerKey = consumerKey;
+			this.timestamp = timestamp;
+			this.nonce = nonce;
+		}
+
+		@Override
+		public int hashCode() {
+			return consumerKey.hashCode() * nonce.hashCode() * Long.valueOf(timestamp).hashCode();
+		}
+
+		@Override
+		public boolean equals(Object obj) {
+			if (obj == null || !(obj instanceof NonceEntry)) {
+				return false;
+			}
+			NonceEntry arg = (NonceEntry) obj;
+			return timestamp == arg.timestamp && consumerKey.equals(arg.consumerKey) && nonce.equals(arg.nonce);
+		}
+
+		public int compareTo(NonceEntry o) {
+			// sort by timestamp
+			if (timestamp < o.timestamp) {
+				return -1;
+			}
+			else if (timestamp == o.timestamp) {
+				int consumerKeyCompare = consumerKey.compareTo(o.consumerKey);
+				if (consumerKeyCompare == 0) {
+					return nonce.compareTo(o.nonce);
+				}
+				else {
+					return consumerKeyCompare;
+				}
+			}
+			else {
+				return 1;
+			}
+		}
+
+		@Override
+		public String toString() {
+			return timestamp + " " + consumerKey + " " + nonce;
+		}
+	}
 }
